import mpctools as mpc
import casadi as cs
import numpy as np
from scipy import integrate


class ModelHelper:
    """Helper around a two-state CSTR model (concentration of A, reactor temperature).

    Offers the continuous-time right-hand side (``getODE``), an explicit-Euler
    discretization of it (``F``), a scaled variant of that step (``F_scale``),
    a reward function (``h``) and a steady-state optimization solved with
    CasADi/IPOPT (``ss_optimization``).
    """

    def __init__(self, nx_list, nu_list, xscale=None, uscale=None, xss=None, uss=None):
        """Store model dimensions, scaling factors and steady-state values.

        :param nx_list: sequence of state counts; ``Nx`` is their sum and
            ``Nx_up`` the first entry (presumably the upstream unit — confirm).
        :param nu_list: sequence of input counts; ``Nu`` is their sum and
            ``Nu_up`` the first entry.
        :param xscale: per-state scale factors, used by ``F_scale``.
        :param uscale: per-input scale factors, used by ``F_scale``.
        :param xss: steady-state states (stored; not read by this class's methods).
        :param uss: steady-state inputs (stored; not read by this class's methods).
        """
        self.Nx_up = nx_list[0]
        self.Nx = sum(nx_list)
        self.Nu_up = nu_list[0]
        self.Nu = sum(nu_list)
        self.xscale = xscale
        self.uscale = uscale
        self.xss = xss
        self.uss = uss

    def _rhs(self, x, u, w=None):
        """Return the CSTR time derivatives ``[dC_A/dt, dT/dt]`` as a plain list.

        Single source of the model parameters for both ``F`` and ``getODE``
        (``ss_optimization`` also evaluates it on CasADi ``SX`` symbols).

        :param x: state; ``x[0]`` is the concentration of A (mol/L) and
            ``x[1]`` the reactor temperature (K).
        :param u: heat-exchange temperature entering through the ``UA`` term
            (presumably the coolant/jacket temperature in K — confirm).
        :param w: additive disturbances on the inlet concentration (``w[0]``)
            and inlet temperature (``w[1]``); ``None`` means no disturbance.
        """
        if w is None:
            # `is None` rather than `== None`: an ndarray `w` would compare
            # element-wise and make the `if` raise ValueError.
            w = np.zeros(self.Nx)

        # Parameters
        q = 100.0  # inlet flow rate, L/min
        Tin = 350.0  # inlet temperature, K
        Cin = 1.0  # inlet concentration, kmol/m^3 (= mol/L)
        V = 100.0  # volume of reactor, L
        k0 = 7.2E10  # rate constant (Arrhenius pre-exponential factor), 1/min
        EoR = 8750.0  # activation energy / gas constant, K
        UA = 5.0E4  # heat transfer coefficient, J/min.K
        rho = 1000.0  # density, g/L
        Cp = 0.239  # heat capacity, J/g.K
        DH = -5.0E4  # heat of reaction, J/mol

        # Concentration of A (mol/L)
        dCa = (q/V)*(Cin + w[0] - x[0]) - k0*np.exp(-EoR/x[1])*x[0]
        # Temperature (K)
        dT = (((q/V)*(Tin + w[1] - x[1]) + (-DH/(rho*Cp))*k0*np.exp(-EoR/x[1])*x[0] + (UA/(V*rho*Cp))*(-x[1]))
              + (UA/(V*rho*Cp))*u)
        return [dCa, dT]

    def F(self, x, u, w=None):
        """Advance the CSTR by one explicit-Euler step of 0.1 min.

        Computes ``x + T * f(x, u, w)``, where ``f`` is the same right-hand
        side that ``getODE`` returns.

        :param x: current state ``[C_A, T]``.
        :param u: heat-exchange input entering through the ``UA`` term.
        :param w: disturbance ``[dCin, dTin]``; ``None`` (default) means zero.
        :return: numpy array holding the next state.
        """
        T = 0.1  # discretization step size, min
        dx = self._rhs(x, u, w)
        return np.array([x[0] + T*dx[0], x[1] + T*dx[1]])

    def F_scale(self, x, u):
        """Euler step expressed in scaled coordinates.

        ``x`` and ``u`` are scaled values (physical = scaled * scale). The
        next physical state is divided by ``xscale`` and returned as a CasADi
        column built with ``cs.vertcat``. Disturbances are taken as zero.
        """
        return cs.vertcat(*self.F(x * cs.DM(self.xscale), u * cs.DM(self.uscale))) / cs.DM(self.xscale)

    def h(self, states, actions):
        """Reward for the RL environment.

        Returns ``1 - states[-14] / (states[18] + 1e-8)`` minus ``0.5`` per unit
        of ``actions['switch']``. The ``1e-8`` presumably avoids division by zero.

        NOTE(review): indices -14 and 18 imply ``states`` is much longer than
        the two-state CSTR vector used elsewhere in this class — confirm which
        state layout callers pass.
        """
        reward = 1 - states[-14] / (states[18]+1e-8) + actions['switch']*-0.5
        return reward

    def getODE(self, x, u, w=None):
        """Continuous-time CSTR right-hand side ``[dC_A/dt, dT/dt]``.

        :param x: state; ``x[0]`` concentration of A (mol/L), ``x[1]`` reactor
            temperature (K).
        :param u: heat-exchange input entering through the ``UA`` term.
        :param w: disturbance on inlet concentration and temperature; ``None``
            (default) means zero. Numpy arrays are accepted.
        :return: numpy array of the two derivatives.
        """
        return np.array(self._rhs(x, u, w))

    def ss_optimization(self, nx, nu, lecost, x0, u0):
        """Solve the steady-state optimization of the continuous-time CSTR.

        Minimizes ``lecost(x, u)`` over states and inputs subject to
        ``getODE(x, u) == 0`` (steady state), the polyhedral constraint
        ``hrepABig @ x <= hrepBBig + tol`` taken from ``RLModelHelper``, and
        box bounds on ``x`` and ``u``. The NLP is built with CasADi ``SX``
        symbols and solved by IPOPT.

        :param nx: number of states (the hard-coded state bounds assume 2).
        :param nu: number of inputs.
        :param lecost: callable ``lecost(x, u)`` returning the scalar CasADi
            objective to minimize.
        :param x0: initial guess for the states.
        :param u0: initial guess for the inputs.
        :return: tuple ``(x_opt, u_opt)`` of 1-D numpy arrays.
        """
        # TODO: make this generic so it can be reused without many modifications.
        # Decision variables: states and inputs stacked together.
        x = cs.SX.sym('x', nx)
        u = cs.SX.sym('u', nu)
        dvar = cs.vertcat(x, u)

        # Objective to minimize.
        le = lecost(x, u)

        # Equality constraints: every time derivative vanishes at steady state.
        gdot = cs.vertcat(self.getODE(x, u))

        # CSTR - CIS specific -------------------------------------------------------------
        # Extra inequality rows hrepABig @ x <= hrepBBig + tol from the RL helper.
        from model.env_rl_helper import RLModelHelper
        env = RLModelHelper()
        hrepABig = env.hrepABig
        hrepBBig = env.hrepBBig
        tol = env.tol
        extra_constraints = cs.vertcat(np.matmul(hrepABig, x))
        gdot = cs.vertcat(gdot, extra_constraints)
        n_g = nx + extra_constraints.size()[0]
        # First nx rows are equalities (0 <= g <= 0); the remaining rows are one-sided.
        glcon = cs.DM(np.zeros(n_g))
        glcon[nx:] = -cs.inf
        gucon = cs.DM(np.zeros(n_g))
        gucon[nx:] = hrepBBig + tol
        # CSTR - CIS specific: end --------------------------------------------------------

        # Box bounds on the decision variables.
        # TODO: take these bounds from env_rl_helper as well.
        xlcon = cs.DM(np.ones(nx)) * np.array((0.0, 345.0))  # C_A >= 0, T >= 345 K
        xucon = cs.DM(np.ones(nx)) * np.array((1.0, 355.0))  # C_A <= 1, T <= 355 K
        ulcon = cs.DM(np.ones(nu)) * 285.0
        uucon = cs.DM(np.ones(nu)) * 315.0
        dvarlcon = cs.vertcat(xlcon, ulcon)
        dvarucon = cs.vertcat(xucon, uucon)

        # Minimize f over x while satisfying lbg <= g <= ubg.
        prob = {'x': dvar, 'f': le, 'g': gdot}

        # NLP solver options.
        opts = {}
        opts["expand"] = True
        # opts["ipopt.print_level"] = 0  # uncomment to silence IPOPT output
        opts["verbose"] = False
        # ma27 is faster and more robust than mumps but must be installed separately.
        opts["ipopt.linear_solver"] = "mumps"
        # opts["ipopt.hessian_approximation"] = "limited-memory"  # for very large problems; less accurate Hessian

        nlp = cs.nlpsol('ss_optimization', 'ipopt', prob, opts)
        res = nlp(lbx=dvarlcon, ubx=dvarucon, lbg=glcon, ubg=gucon,
                  x0=cs.vertcat(x0, u0))
        sol = res['x'].full().ravel()
        print("Steady-state complete")
        return sol[0:nx], sol[nx:]

